import gradio as gr
import os
import subprocess
import argparse

from utils.common_ui import (base_ui, train_param_ui)
from utils.log_util import Logger

log = Logger(debug=True).logger


def ui(**kwargs):
    from theme.winter import Winter
    css = ''
    if os.path.exists('./theme/style.css'):
        with open(os.path.join('./theme/style.css'), 'r', encoding='utf8') as file:
            css += file.read() + '\n'

    with gr.Blocks(title="AI BOY LORA", theme=Winter(), css=css) as interface:
        gr.Markdown('''## <span style='color:brown'>AIBOY 炼丹炉</span> ''')
        (
            base_model_name,
            train_dataset_path,
            output_model_path,
            output_model_name
        ) = base_ui()

        (
            batch_size,
            learning_rate,
            # train_steps,
            epoch,
            save_model_every_n_epochs
        ) = train_param_ui()

        train_button = gr.Button(value="开始训练")
        train_button.click(
            train_lora,
            show_progress=False,
            inputs=[
                base_model_name,
                train_dataset_path,
                output_model_path,
                output_model_name,
                batch_size,
                learning_rate,
                # train_steps,
                epoch,
                save_model_every_n_epochs
            ])

    launch_kwargs = {}
    username = kwargs.get('username')
    password = kwargs.get('password')
    server_port = kwargs.get('server_port', 0)
    share = kwargs.get('share', False)
    server_name = kwargs.get('listen')

    launch_kwargs['server_name'] = server_name
    if username and password:
        launch_kwargs['auth'] = (username, password)
    if server_port > 0:
        launch_kwargs['server_port'] = server_port
    if share:
        launch_kwargs['share'] = share
    log.info(launch_kwargs)
    interface.launch(**launch_kwargs)


def train_lora(
        base_model_name,
        train_dataset_path,
        output_model_path,
        output_model_name,
        batch_size,
        learning_rate,
        epoch,
        save_model_every_n_epochs
):
    cmd_args = f'accelerate launch --num_cpu_threads_per_process=8 "train_network.py"'
    cmd_args += f' --logging_dir="logs"'
    cmd_args += f' --pretrained_model_name_or_path="{base_model_name}"'
    cmd_args += f' --train_data_dir="{train_dataset_path}"'
    cmd_args += f' --output_dir="{output_model_path}"'
    cmd_args += f' --output_name="{output_model_name}"'
    cmd_args += f' --train_batch_size={batch_size}'
    cmd_args += f' --learning_rate={learning_rate}'
    # cmd_args += f' --max_train_steps={int(train_steps)}'
    cmd_args += f' --save_every_n_epochs={save_model_every_n_epochs}'
    cmd_args += f' --max_train_epochs={epoch}'
    cmd_args += f' --resolution=512'
    cmd_args += f' --optimizer_type="AdamW8bit"'
    cmd_args += f' --xformers'
    cmd_args += f' --mixed_precision="fp16"'
    cmd_args += f' --cache_latents'
    cmd_args += f' --gradient_checkpointing'
    cmd_args += f' --network_module=networks.lora'

    log.info(f"训练参数：{cmd_args}")

    subprocess.run(cmd_args)


if __name__ == '__main__':
    # CLI options; the parsed namespace keys match ui()'s keyword arguments.
    cli = argparse.ArgumentParser()
    cli.add_argument('--listen', type=str, default='127.0.0.1',
                     help='IP to listen on for connections to Gradio')
    cli.add_argument('--username', type=str, default='',
                     help='Username for authentication')
    cli.add_argument('--password', type=str, default='',
                     help='Password for authentication')
    cli.add_argument('--server_port', type=int, default=0,
                     help='Port to run the server listener on')
    cli.add_argument('--share', action='store_true',
                     help='Share the gradio UI')

    options = vars(cli.parse_args())
    ui(**options)
